rename mrope related function, params

commit 12f17f754d (parent ac2089c378)
HimariO, 2024-12-08 01:32:19 +08:00
8 changed files with 27 additions and 27 deletions


@@ -1399,7 +1399,7 @@ llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual

-llama-qwen2vl-cli: examples/llava/minicpmv-cli.cpp \
+llama-qwen2vl-cli: examples/llava/qwen2vl-cli.cpp \
 	examples/llava/llava.cpp \
 	examples/llava/llava.h \
 	examples/llava/clip.cpp \


@@ -759,7 +759,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
         if (ctx->has_qwen2vl_merger) {
-            Q = ggml_mrope_ext(
+            Q = ggml_rope_multi(
                 ctx0, Q, positions, nullptr,
                 d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
         }
@@ -772,7 +772,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
         if (ctx->has_qwen2vl_merger) {
-            K = ggml_mrope_ext(
+            K = ggml_rope_multi(
                 ctx0, K, positions, nullptr,
                 d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
         }


@@ -1447,7 +1447,7 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);

-    GGML_API struct ggml_tensor * ggml_mrope_ext(
+    GGML_API struct ggml_tensor * ggml_rope_multi(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
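Note: the renamed entry point keeps the ggml_rope_ext argument list and adds the
sections split. A minimal calling sketch follows, mirroring the test file at the
bottom of this commit; the helper name, tensor shapes, and section values are
illustrative assumptions, not part of the diff.

    // Hypothetical helper showing a ggml_rope_multi call site (C++).
    // The sections/n_dims values are assumed for illustration; real models
    // read them from GGUF metadata.
    static struct ggml_tensor * apply_mrope(struct ggml_context * ctx0,
                                            struct ggml_tensor  * cur,    // [n_embd_head, n_head, n_tokens]
                                            struct ggml_tensor  * pos) {  // 4 position ids per token
        int sections[4] = {16, 24, 24, 0};  // assumed temporal/height/width/extra split
        return ggml_rope_multi(
            ctx0, cur, pos, /*freq_factors=*/nullptr,
            /*n_dims=*/64, sections, GGML_ROPE_TYPE_MROPE,
            /*n_ctx_orig=*/32768, /*freq_base=*/1000000.0f, /*freq_scale=*/1.0f,
            /*ext_factor=*/0.0f, /*attn_factor=*/1.0f,
            /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);
    }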


@@ -9317,7 +9317,7 @@ static void ggml_compute_forward_rope_f32(
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

     if (is_mrope) {
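Note: the mode tests above only work because the rope-type constants are bit
flags, with the vision type including the mrope bit; that is why is_mrope uses
a bitwise AND while is_vision needs full equality. The constant values in the
sketch below are an assumption about the ggml.h defines of this era, not part
of the diff.

    // Standalone illustration (C++). Flag values are assumed:
    // NEOX = 2, MROPE = 8, VISION = 24 (= MROPE | 16).
    #include <cstdio>

    enum { ROPE_NEOX = 2, ROPE_MROPE = 8, ROPE_VISION = 24 };

    int main() {
        const int mode = ROPE_VISION;
        const bool is_mrope  = mode & ROPE_MROPE;    // true: vision implies the mrope bit
        const bool is_vision = mode == ROPE_VISION;  // true only for the exact vision mode
        std::printf("is_mrope=%d is_vision=%d\n", is_mrope, is_vision);
        return 0;
    }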


@@ -114,7 +114,7 @@ static __global__ void rope_neox(
 }

 template<typename T, bool has_ff>
-static __global__ void rope_mrope(
+static __global__ void rope_multi(
     const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -261,7 +261,7 @@ static void rope_neox_cuda(
 }

 template<typename T>
-static void rope_mrope_cuda(
+static void rope_multi_cuda(
     const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
@@ -272,12 +272,12 @@ static void rope_mrope_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);

     if (freq_factors == nullptr) {
-        rope_mrope<T, false><<<block_nums, block_dims, 0, stream>>>(
+        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors, sections
         );
     } else {
-        rope_mrope<T, true><<<block_nums, block_dims, 0, stream>>>(
+        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors, sections
         );
@@ -339,20 +339,20 @@ static void rope_neox_cuda_f32(
     rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }

-static void rope_mrope_cuda_f16(
+static void rope_multi_cuda_f16(
     const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
 ) {
-    rope_mrope_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+    rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 }

-static void rope_mrope_cuda_f32(
+static void rope_multi_cuda_f32(
     const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
 ) {
-    rope_mrope_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+    rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 }

 static void rope_vision_cuda_f16(
@@ -455,12 +455,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
     } else if (is_mrope && !is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_mrope_cuda_f32(
+            rope_multi_cuda_f32(
                 (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, sections, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_mrope_cuda_f16(
+            rope_multi_cuda_f16(
                 (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, sections, stream
             );
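Note: the host wrappers pick rope_multi<T, false> or rope_multi<T, true> so the
optional freq_factors lookup is resolved at compile time instead of per
element. Below is a plain C++ sketch of that dispatch pattern with hypothetical
names; the real kernels rotate dimension pairs rather than scaling values.

    // Sketch of the template-on-bool dispatch used above (plain C++,
    // hypothetical names). When HasFF is false the compiler drops the load
    // and the ternary entirely.
    template <bool HasFF>
    static void apply_factors(const float * x, float * dst, const float * freq_factors, int n) {
        for (int i = 0; i < n; ++i) {
            const float ff = HasFF ? freq_factors[i] : 1.0f;  // folded to 1.0f when HasFF == false
            dst[i] = x[i] / ff;
        }
    }

    static void apply_factors_dispatch(const float * x, float * dst, const float * freq_factors, int n) {
        if (freq_factors == nullptr) {
            apply_factors<false>(x, dst, nullptr, n);
        } else {
            apply_factors<true>(x, dst, freq_factors, n);
        }
    }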


@@ -3548,7 +3548,7 @@ struct ggml_tensor * ggml_rope(
     );
 }

-struct ggml_tensor * ggml_mrope_ext(
+struct ggml_tensor * ggml_rope_multi(
     struct ggml_context * ctx,
     struct ggml_tensor  * a,
     struct ggml_tensor  * b,
@@ -3563,7 +3563,7 @@ struct ggml_tensor * ggml_mrope_ext(
     float attn_factor,
     float beta_fast,
     float beta_slow) {
+    // Multimodal Rotary Position Embedding
     GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
     GGML_ASSERT(ggml_is_vector(b));


@@ -2436,7 +2436,7 @@ struct llama_hparams {
     float    rope_freq_scale_train;
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul;
-    std::array<int, 4> rope_mrope_sections;
+    std::array<int, 4> rope_sections;

     // for State Space Models
     uint32_t ssm_d_conv = 0;
@@ -2493,7 +2493,7 @@ struct llama_hparams {
         if (this->rope_finetuned != other.rope_finetuned) return true;
         if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
-        if (this->rope_mrope_sections != other.rope_mrope_sections) return true;
+        if (this->rope_sections != other.rope_sections) return true;

         if (this->ssm_d_conv != other.ssm_d_conv) return true;
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
@@ -5717,8 +5717,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_QWEN2VL:
             {
-                std::fill(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), 0);
-                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_mrope_sections, 4, true);
+                std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
             }
             // fall through
         case LLM_ARCH_QWEN2:
@@ -12543,7 +12543,7 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

         int sections[4];
-        std::copy(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), sections);
+        std::copy(hparams.rope_sections.begin(), hparams.rope_sections.end(), sections);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -12572,7 +12572,7 @@ struct llm_build_context {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);

-                Qcur = ggml_mrope_ext(
+                Qcur = ggml_rope_multi(
                     ctx0,
                     ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12580,7 +12580,7 @@ struct llm_build_context {
                 );
                 cb(Qcur, "Qcur", il);

-                Kcur = ggml_mrope_ext(
+                Kcur = ggml_rope_multi(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
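Note: the four sections entries partition the rotary dimension pairs among the
position streams, so each slice of a head dimension rotates by a different
coordinate (temporal, height, width, plus a spare). Below is a sketch of that
mapping under my reading of M-RoPE; the helper and the example split are
illustrative, the real values come from the LLM_KV_ROPE_DIMENSION_SECTIONS key
loaded above.

    // Hypothetical helper: map a rotary dimension-pair index to the section
    // (0 = temporal, 1 = height, 2 = width, 3 = extra) whose position id
    // rotates it. Example values assume a Qwen2VL-style {16, 24, 24, 0} split.
    static int section_of_dim(const int sections[4], int dim_pair) {
        int acc = 0;
        for (int s = 0; s < 4; ++s) {
            acc += sections[s];
            if (dim_pair < acc) {
                return s;
            }
        }
        return 3;  // past the configured sections: treat as the last stream
    }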


@@ -193,20 +193,20 @@ int main(int /*argc*/, const char ** /*argv*/) {
                 // [[100, 101, 102, ..., 172],
                 //  [101, 102, 103, ..., 173],
                 //  [102, 103, 104, ..., 174]]
-                r0 = ggml_mrope_ext(
+                r0 = ggml_rope_multi(
                     ctx0, x, p0, nullptr,
                     n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);

                 // [[-67, -67, -67, ..., -67]
                 //  [-67, -67, -67, ..., -67]
                 //  [-67, -67, -67, ..., -67]]
-                r1 = ggml_mrope_ext(
+                r1 = ggml_rope_multi(
                     ctx0, r0, p1, nullptr,
                     n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);

                 // [[33, 34, 35, ..., 105]
                 //  [34, 35, 36, ..., 106]
                 //  [35, 36, 37, ..., 107]]
-                r2 = ggml_mrope_ext(
+                r2 = ggml_rope_multi(
                     ctx0, x, p2, nullptr,
                     n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
             }
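Note: each position tensor in the test (p0, p1, p2) is a flat vector holding
four position ids per token, one per section stream, which is what the
ggml_is_vector(b) assert above permits. A sketch of building such a buffer
follows; the section-major layout is my assumption from how the kernels index
pos, not something spelled out in this diff.

    // Illustrative builder (C++): four position streams laid out
    // section-major, i.e. pos[k * n_tokens + j] is stream k's id for token j.
    // For plain text all four streams carry the same id; vision tokens would
    // diverge in the height/width streams.
    #include <cstdint>
    #include <vector>

    static std::vector<int32_t> make_mrope_positions(int n_tokens, int32_t start) {
        std::vector<int32_t> pos(4 * n_tokens);
        for (int k = 0; k < 4; ++k) {
            for (int j = 0; j < n_tokens; ++j) {
                pos[k * n_tokens + j] = start + j;
            }
        }
        return pos;
    }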