rename mrope
related function, params
This commit is contained in:
parent
ac2089c378
commit
12f17f754d
8 changed files with 27 additions and 27 deletions
2
Makefile
2
Makefile
|
@ -1399,7 +1399,7 @@ llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp \
|
||||||
$(OBJ_ALL)
|
$(OBJ_ALL)
|
||||||
$(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
|
$(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
|
||||||
|
|
||||||
llama-qwen2vl-cli: examples/llava/minicpmv-cli.cpp \
|
llama-qwen2vl-cli: examples/llava/qwen2vl-cli.cpp \
|
||||||
examples/llava/llava.cpp \
|
examples/llava/llava.cpp \
|
||||||
examples/llava/llava.h \
|
examples/llava/llava.h \
|
||||||
examples/llava/clip.cpp \
|
examples/llava/clip.cpp \
|
||||||
|
|
|
@ -759,7 +759,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
|
|
||||||
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
||||||
if (ctx->has_qwen2vl_merger) {
|
if (ctx->has_qwen2vl_merger) {
|
||||||
Q = ggml_mrope_ext(
|
Q = ggml_rope_multi(
|
||||||
ctx0, Q, positions, nullptr,
|
ctx0, Q, positions, nullptr,
|
||||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||||
}
|
}
|
||||||
|
@ -772,7 +772,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
|
|
||||||
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
||||||
if (ctx->has_qwen2vl_merger) {
|
if (ctx->has_qwen2vl_merger) {
|
||||||
K = ggml_mrope_ext(
|
K = ggml_rope_multi(
|
||||||
ctx0, K, positions, nullptr,
|
ctx0, K, positions, nullptr,
|
||||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1447,7 +1447,7 @@ extern "C" {
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow);
|
float beta_slow);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_mrope_ext(
|
GGML_API struct ggml_tensor * ggml_rope_multi(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
|
|
|
@ -9317,7 +9317,7 @@ static void ggml_compute_forward_rope_f32(
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||||
|
|
||||||
if (is_mrope) {
|
if (is_mrope) {
|
||||||
|
|
|
@ -114,7 +114,7 @@ static __global__ void rope_neox(
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, bool has_ff>
|
template<typename T, bool has_ff>
|
||||||
static __global__ void rope_mrope(
|
static __global__ void rope_multi(
|
||||||
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
|
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
|
||||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||||
|
@ -261,7 +261,7 @@ static void rope_neox_cuda(
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void rope_mrope_cuda(
|
static void rope_multi_cuda(
|
||||||
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
|
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
GGML_ASSERT(ne0 % 2 == 0);
|
||||||
|
@ -272,12 +272,12 @@ static void rope_mrope_cuda(
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
|
|
||||||
if (freq_factors == nullptr) {
|
if (freq_factors == nullptr) {
|
||||||
rope_mrope<T, false><<<block_nums, block_dims, 0, stream>>>(
|
rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||||
theta_scale, freq_factors, sections
|
theta_scale, freq_factors, sections
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
rope_mrope<T, true><<<block_nums, block_dims, 0, stream>>>(
|
rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||||
theta_scale, freq_factors, sections
|
theta_scale, freq_factors, sections
|
||||||
);
|
);
|
||||||
|
@ -339,20 +339,20 @@ static void rope_neox_cuda_f32(
|
||||||
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_mrope_cuda_f16(
|
static void rope_multi_cuda_f16(
|
||||||
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
||||||
) {
|
) {
|
||||||
|
|
||||||
rope_mrope_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_mrope_cuda_f32(
|
static void rope_multi_cuda_f32(
|
||||||
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
||||||
) {
|
) {
|
||||||
|
|
||||||
rope_mrope_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_vision_cuda_f16(
|
static void rope_vision_cuda_f16(
|
||||||
|
@ -455,12 +455,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
}
|
}
|
||||||
} else if (is_mrope && !is_vision) {
|
} else if (is_mrope && !is_vision) {
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
rope_mrope_cuda_f32(
|
rope_multi_cuda_f32(
|
||||||
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||||
attn_factor, corr_dims, freq_factors, sections, stream
|
attn_factor, corr_dims, freq_factors, sections, stream
|
||||||
);
|
);
|
||||||
} else if (src0->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16) {
|
||||||
rope_mrope_cuda_f16(
|
rope_multi_cuda_f16(
|
||||||
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||||
attn_factor, corr_dims, freq_factors, sections, stream
|
attn_factor, corr_dims, freq_factors, sections, stream
|
||||||
);
|
);
|
||||||
|
|
|
@ -3548,7 +3548,7 @@ struct ggml_tensor * ggml_rope(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_mrope_ext(
|
struct ggml_tensor * ggml_rope_multi(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
|
@ -3563,7 +3563,7 @@ struct ggml_tensor * ggml_mrope_ext(
|
||||||
float attn_factor,
|
float attn_factor,
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
|
// Multimodal Rotary Position Embedding
|
||||||
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_vector(b));
|
GGML_ASSERT(ggml_is_vector(b));
|
||||||
|
|
|
@ -2436,7 +2436,7 @@ struct llama_hparams {
|
||||||
float rope_freq_scale_train;
|
float rope_freq_scale_train;
|
||||||
uint32_t n_ctx_orig_yarn;
|
uint32_t n_ctx_orig_yarn;
|
||||||
float rope_yarn_log_mul;
|
float rope_yarn_log_mul;
|
||||||
std::array<int, 4> rope_mrope_sections;
|
std::array<int, 4> rope_sections;
|
||||||
|
|
||||||
// for State Space Models
|
// for State Space Models
|
||||||
uint32_t ssm_d_conv = 0;
|
uint32_t ssm_d_conv = 0;
|
||||||
|
@ -2493,7 +2493,7 @@ struct llama_hparams {
|
||||||
|
|
||||||
if (this->rope_finetuned != other.rope_finetuned) return true;
|
if (this->rope_finetuned != other.rope_finetuned) return true;
|
||||||
if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
|
if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
|
||||||
if (this->rope_mrope_sections != other.rope_mrope_sections) return true;
|
if (this->rope_sections != other.rope_sections) return true;
|
||||||
|
|
||||||
if (this->ssm_d_conv != other.ssm_d_conv) return true;
|
if (this->ssm_d_conv != other.ssm_d_conv) return true;
|
||||||
if (this->ssm_d_inner != other.ssm_d_inner) return true;
|
if (this->ssm_d_inner != other.ssm_d_inner) return true;
|
||||||
|
@ -5717,8 +5717,8 @@ static void llm_load_hparams(
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_QWEN2VL:
|
case LLM_ARCH_QWEN2VL:
|
||||||
{
|
{
|
||||||
std::fill(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), 0);
|
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
|
||||||
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_mrope_sections, 4, true);
|
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
|
||||||
}
|
}
|
||||||
// fall through
|
// fall through
|
||||||
case LLM_ARCH_QWEN2:
|
case LLM_ARCH_QWEN2:
|
||||||
|
@ -12543,7 +12543,7 @@ struct llm_build_context {
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
int sections[4];
|
int sections[4];
|
||||||
std::copy(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), sections);
|
std::copy(hparams.rope_sections.begin(), hparams.rope_sections.end(), sections);
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
@ -12572,7 +12572,7 @@ struct llm_build_context {
|
||||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||||
cb(Vcur, "Vcur", il);
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
Qcur = ggml_mrope_ext(
|
Qcur = ggml_rope_multi(
|
||||||
ctx0,
|
ctx0,
|
||||||
ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
@ -12580,7 +12580,7 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_mrope_ext(
|
Kcur = ggml_rope_multi(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
|
|
@ -193,20 +193,20 @@ int main(int /*argc*/, const char ** /*argv*/) {
|
||||||
// [[100, 101, 102, ..., 172],
|
// [[100, 101, 102, ..., 172],
|
||||||
// [101, 102, 103, ..., 173],
|
// [101, 102, 103, ..., 173],
|
||||||
// [102, 103, 104, ..., 174]]
|
// [102, 103, 104, ..., 174]]
|
||||||
r0 = ggml_mrope_ext(
|
r0 = ggml_rope_multi(
|
||||||
ctx0, x, p0, nullptr,
|
ctx0, x, p0, nullptr,
|
||||||
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
||||||
// [[-67, -67, -67, ..., -67]
|
// [[-67, -67, -67, ..., -67]
|
||||||
// [-67, -67, -67, ..., -67]
|
// [-67, -67, -67, ..., -67]
|
||||||
// [-67, -67, -67, ..., -67]]
|
// [-67, -67, -67, ..., -67]]
|
||||||
r1 = ggml_mrope_ext(
|
r1 = ggml_rope_multi(
|
||||||
ctx0, r0, p1, nullptr,
|
ctx0, r0, p1, nullptr,
|
||||||
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
||||||
|
|
||||||
// [[33, 34, 35, ..., 105]
|
// [[33, 34, 35, ..., 105]
|
||||||
// [34, 35, 36, ..., 106]
|
// [34, 35, 36, ..., 106]
|
||||||
// [35, 36, 37, ..., 107]]
|
// [35, 36, 37, ..., 107]]
|
||||||
r2 = ggml_mrope_ext(
|
r2 = ggml_rope_multi(
|
||||||
ctx0, x, p2, nullptr,
|
ctx0, x, p2, nullptr,
|
||||||
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue