llama : add Qwen2VL support + multimodal RoPE (#10361)
* Barebone Qwen2VL LLM converter
* Add Qwen2VL cli entrypoint
* [WIP] add qwen2vl arch
* Verify m-rope output
* Add vl-rope/2d-rope support for qwen2vl ViT
* update qwen2vl cli tool
* update 5D tensor op workaround
* [WIP] qwen2vl vision model
* make batch and clip utils compatible with qwen2vl
* [WIP] create inference workflow, gguf convert script bug fix
* correct vision-rope behavior, add the missing last layer back to ViT
* add arg parser to qwen2vl_surgery
* replace variable size array with vector
* cuda-gdb cmake preset
* add fp32 mrope, vision rope kernel
* add fp16 support for qwen2vl and m-rope
* add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION`
* fix rope op mode switching, outdated func args
* update `llama_hparams`
* update to keep up with upstream changes
* resolve linter, test errors
* add makefile entry, update special image padding token
* add mrope unit test, fix a few compiler warnings
* rename `mrope`-related functions and params
* minor updates to debug util, bug fixes
* add `m-rope` testcase to `test-backend-ops`
* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix trailing whitespace
* store `llama_hparams.rope_sections` with fixed size array
* update position id tensor size check in GGML_OP_ROPE
* minor updates
* update `ggml_backend_*_supports_op` of unsupported backends
* remove old `rope_section` compare operator

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent 56eea0781c
commit ba1cb19cdd
24 changed files with 1873 additions and 114 deletions
src/llama.cpp (178 changes)
@@ -163,6 +163,7 @@ enum llm_arch {
     LLM_ARCH_QWEN,
     LLM_ARCH_QWEN2,
     LLM_ARCH_QWEN2MOE,
+    LLM_ARCH_QWEN2VL,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PLAMO,
@@ -217,6 +218,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN,     "qwen"     },
     { LLM_ARCH_QWEN2,    "qwen2"    },
     { LLM_ARCH_QWEN2MOE, "qwen2moe" },
+    { LLM_ARCH_QWEN2VL,  "qwen2vl"  },
     { LLM_ARCH_PHI2,     "phi2"     },
     { LLM_ARCH_PHI3,     "phi3"     },
     { LLM_ARCH_PLAMO,    "plamo"    },
@@ -308,6 +310,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
+    LLM_KV_ROPE_DIMENSION_SECTIONS,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
     LLM_KV_ROPE_SCALING_TYPE,
@@ -424,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,    "%s.rope.dimension_count"    },
+    { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
     { LLM_KV_ROPE_FREQ_BASE,          "%s.rope.freq_base"          },
     { LLM_KV_ROPE_SCALE_LINEAR,       "%s.rope.scale_linear"       },
     { LLM_KV_ROPE_SCALING_TYPE,       "%s.rope.scaling.type"       },
@@ -898,6 +902,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_QWEN2VL,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT,      "output" },
+            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_QWEN2MOE,
         {
@@ -2474,11 +2495,12 @@ struct llama_hparams {
     uint32_t time_decay_extra_dim = 0;
     uint32_t wkv_head_size = 0;
 
-    float rope_attn_factor = 1.0f;
-    float rope_freq_base_train;
-    float rope_freq_scale_train;
-    uint32_t n_ctx_orig_yarn;
-    float rope_yarn_log_mul;
+    float    rope_attn_factor = 1.0f;
+    float    rope_freq_base_train;
+    float    rope_freq_scale_train;
+    uint32_t n_ctx_orig_yarn;
+    float    rope_yarn_log_mul;
+    int      rope_sections[4];
 
     // for State Space Models
     uint32_t ssm_d_conv = 0;
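For context, `rope_sections` describes how the rotary dimension pairs are split between the m-rope position streams (temporal, height, width, plus a spare slot). A minimal sketch of that mapping, assuming the `{16, 24, 24, 0}` split commonly used for Qwen2-VL with `n_rot = 128` (illustrative values, not taken from this diff):

```cpp
#include <cstdio>

// Sketch: which position stream (0 = temporal, 1 = height, 2 = width, 3 = spare)
// a given rotary dimension pair falls into, given a cumulative section layout.
static int mrope_section_of(int dim_pair, const int sections[4]) {
    int acc = 0;
    for (int s = 0; s < 4; ++s) {
        acc += sections[s];
        if (dim_pair < acc) {
            return s;
        }
    }
    return 3; // anything past the declared sections uses the last stream
}

int main() {
    const int sections[4] = {16, 24, 24, 0}; // assumed Qwen2-VL style split, n_rot/2 = 64 pairs
    for (int d = 0; d < 64; d += 8) {
        printf("dim pair %2d -> section %d\n", d, mrope_section_of(d, sections));
    }
    return 0;
}
```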
@@ -2535,6 +2557,9 @@ struct llama_hparams {
 
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
+        if (std::equal(std::begin(this->rope_sections),
+                       std::end(this->rope_sections),
+                       std::begin(other.rope_sections))) return true;
 
         if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
@@ -3378,6 +3403,11 @@ struct llama_context {
     // whether we are computing encoder output or decoder output
     bool is_encoding = false;
 
+    // TODO: find a better way to accommodate mutli-dimension position encoding methods
+    // number of position id each token get, 1 for each token in most cases.
+    // when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
+    int n_pos_per_token = 1;
+
     // output of the encoder part of the encoder-decoder models
     std::vector<float> embd_enc;
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
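The comment block above is the crux of the change: with m-rope each token carries more than one position id. As a rough illustration of the idea (not code from this PR; the names and exact offsets are assumptions), text tokens can repeat the same index across all three streams while image-patch tokens take their spatial indices from the patch grid:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: assign (t, h, w) position triples in an m-rope style scheme.
// Text tokens repeat the same index; image patches keep the temporal index fixed
// and derive the spatial indices from their (row, column) in the grid.
struct pos3 { int32_t t, h, w; };

static std::vector<pos3> assign_positions(int n_text, int grid_h, int grid_w) {
    std::vector<pos3> pos;
    int32_t p = 0;
    for (int i = 0; i < n_text; ++i, ++p) {
        pos.push_back({p, p, p});             // text token
    }
    for (int y = 0; y < grid_h; ++y) {
        for (int x = 0; x < grid_w; ++x) {
            pos.push_back({p, p + y, p + x}); // image patch at (y, x)
        }
    }
    return pos;
}

int main() {
    const auto pos = assign_positions(/*n_text =*/ 5, /*grid_h =*/ 2, /*grid_w =*/ 3);
    for (const auto & q : pos) {
        printf("(%d, %d, %d)\n", q.t, q.h, q.w);
    }
    return 0;
}
```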
@@ -5747,6 +5777,13 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN2VL:
+            {
+                std::array<int, 4> section_dims;
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, section_dims, 4, true);
+                std::copy(section_dims.begin(), section_dims.begin() + 4, std::begin(hparams.rope_sections));
+            }
+            // fall through
         case LLM_ARCH_QWEN2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
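On the conversion side, the new `%s.rope.dimension_sections` key (so `qwen2vl.rope.dimension_sections` for this arch) is read as an array of four integers. A hedged sketch of how a GGUF writer could emit it with the public gguf API; the `16/24/24/0` values are assumed, not read from this diff:

```cpp
#include "ggml.h"   // gguf_* declarations lived in ggml.h at the time of this commit

// Sketch: write the rope section split into GGUF metadata and save it.
int main() {
    struct gguf_context * gctx = gguf_init_empty();
    const int32_t sections[4] = {16, 24, 24, 0};   // assumed Qwen2-VL-7B split, illustrative
    gguf_set_arr_data(gctx, "qwen2vl.rope.dimension_sections",
                      GGUF_TYPE_INT32, sections, 4);
    gguf_write_to_file(gctx, "sections-only.gguf", /*only_meta =*/ true);
    gguf_free(gctx);
    return 0;
}
```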
@@ -8167,6 +8204,7 @@ static bool llm_load_tensors(
                     }
                 } break;
             case LLM_ARCH_QWEN2:
+            case LLM_ARCH_QWEN2VL:
                 {
                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -12556,6 +12594,124 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_qwen2vl() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
+        cb(lctx.inp_pos, "inp_pos", -1);
+        ggml_set_input(lctx.inp_pos);
+        struct ggml_tensor * inp_pos = lctx.inp_pos;
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        int sections[4];
+        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_multi(
+                    ctx0,
+                    ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_multi(
+                    ctx0,
+                    ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_qwen2moe() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
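The graph above is the first in-tree user of `ggml_rope_multi`, which takes the extra `sections` argument on top of the usual `ggml_rope_ext` parameters. A minimal standalone sketch of calling it on the CPU backend, in the spirit of the `test-backend-ops` case mentioned in the commit message; tensor shapes, positions, and frequency values are illustrative assumptions:

```cpp
#include "ggml.h"
#include <cstdint>

int main() {
    const int n_embd_head = 128, n_head = 4, n_tokens = 6;

    struct ggml_init_params params = { /*.mem_size   =*/ 64 * 1024 * 1024,
                                       /*.mem_buffer =*/ nullptr,
                                       /*.no_alloc   =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // query-like activations: [head_dim, n_head, n_tokens]
    struct ggml_tensor * q   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
    // m-rope expects 4 position ids per token (temporal, height, width, spare)
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens * 4);

    ggml_set_f32(q, 1.0f);
    for (int i = 0; i < n_tokens * 4; ++i) {
        ((int32_t *) pos->data)[i] = i % n_tokens;   // placeholder positions
    }

    int sections[4] = {16, 24, 24, 0};               // assumed Qwen2-VL style split
    struct ggml_tensor * out = ggml_rope_multi(
        ctx, q, pos, nullptr,
        n_embd_head, sections, GGML_ROPE_TYPE_MROPE,
        /*n_ctx_orig =*/ 32768, /*freq_base =*/ 1000000.0f, /*freq_scale =*/ 1.0f,
        /*ext_factor =*/ 0.0f, /*attn_factor =*/ 1.0f, /*beta_fast =*/ 32.0f, /*beta_slow =*/ 1.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
    return 0;
}
```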
@@ -16657,6 +16813,11 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_qwen2();
             } break;
+        case LLM_ARCH_QWEN2VL:
+            {
+                lctx.n_pos_per_token = 4;
+                result = llm.build_qwen2vl();
+            } break;
         case LLM_ARCH_QWEN2MOE:
             {
                 result = llm.build_qwen2moe();
@@ -16875,8 +17036,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
 
     if (ubatch.pos && lctx.inp_pos) {
         const int64_t n_tokens = ubatch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        auto n_pos = lctx.n_pos_per_token;
+        ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
     }
 
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
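Note the batch contract this implies: `ubatch.pos` must now hold `n_tokens * n_pos_per_token` entries. A hedged sketch of how a caller could fill such a buffer, assuming the four streams are stored as consecutive blocks of `n_tokens` each (temporal, height, width, spare); that block layout is an assumption about the PR's convention and is not shown in this hunk:

```cpp
#include <cstdint>
#include <vector>

// Sketch: build a position buffer with 4 streams stored as consecutive blocks
// of n_tokens each. The block layout is an assumption, not taken from this diff.
static std::vector<int32_t> make_mrope_pos(int n_tokens,
                                           const std::vector<int32_t> & t,
                                           const std::vector<int32_t> & h,
                                           const std::vector<int32_t> & w) {
    std::vector<int32_t> pos(4 * n_tokens, 0);
    for (int i = 0; i < n_tokens; ++i) {
        pos[0 * n_tokens + i] = t[i]; // temporal
        pos[1 * n_tokens + i] = h[i]; // height
        pos[2 * n_tokens + i] = w[i]; // width
        // fourth block left at 0 (unused by a {16, 24, 24, 0} section split)
    }
    return pos;
}
```

A caller would then point `llama_batch.pos` at the resulting buffer before decoding.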
@@ -20009,6 +20170,9 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MINICPM3:
             return LLAMA_ROPE_TYPE_NEOX;
 
+        case LLM_ARCH_QWEN2VL:
+            return LLAMA_ROPE_TYPE_MROPE;
+
         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:
             GGML_ABORT("unknown architecture");
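Downstream code can branch on the new rope type to decide how many position ids to supply per token. A small hedged sketch; the helper name is ours, not part of the public API:

```cpp
#include "llama.h"

// Sketch: pick the number of position ids per token from the model's rope type.
// LLAMA_ROPE_TYPE_MROPE is the value this hunk returns for Qwen2-VL.
static int pos_per_token(const struct llama_model * model) {
    switch (llama_rope_type(model)) {
        case LLAMA_ROPE_TYPE_MROPE:
            return 4;   // temporal / height / width / spare, matching n_pos_per_token = 4 above
        default:
            return 1;
    }
}
```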